{
 "cells": [
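  {
   "cell_type": "markdown",
   "source": [
    "# Load packages and prepare the data\n",
    "1. The only requirement is that every data file has a `content` column (I preprocessed the data myself, so it does).\n",
    "2. Use the `glob` package to collect the paths of all data files.\n",
    "3. Use the `random` package to pick 50 of those paths at random as the validation set; the remaining files form the training set."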
   ],
   "metadata": {
    "collapsed": false
   }
  },
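  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, DatasetDict\n",
    "from glob import glob\n",
    "import random\n",
    "random.seed(42)\n",
    "\n",
    "# Sort the paths: glob order depends on the filesystem, so the seeded sample is only reproducible on a sorted list.\n",
    "all_file_list = sorted(glob(pathname=\"gpt2_data/*/**\"))\n",
    "# Hold out 50 files for validation; all remaining files are used for training.\n",
    "test_file_list = random.sample(all_file_list, 50)\n",
    "test_file_set = set(test_file_list)  # set membership is O(1) per lookup\n",
    "train_file_list = [i for i in all_file_list if i not in test_file_set]\n",
    "\n",
    "len(train_file_list), len(test_file_list)"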
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Build the dataset\n",
    "1. Put the paths into a dictionary whose keys are `train` and `valid`; the value for each key is the corresponding list of file paths."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_datasets = load_dataset(\n",
    "    \"csv\",\n",
    "    data_files={\"train\": train_file_list, \"valid\": test_file_list},\n",
    "    cache_dir=\"cache_data\",\n",
    ")\n",
    "\n",
    "raw_datasets"
   ]
  },
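  {
   "cell_type": "markdown",
   "source": [
    "# Tokenizer\n",
    "1. The tokenizer is the most critical step. Since we are working with Chinese text, `bert-base-chinese` is sufficient.\n",
    "2. If your corpus contains other languages, you can use a multilingual tokenizer instead. The choice matters little as long as the tokenizer covers your data.\n",
    "3. `context_length = 512` sets the maximum length of each text sequence. I use 512 here; if your GPU has little memory, lower it, for example to 128. Text beyond that length is not simply truncated and thrown away: the text is split repeatedly into chunks of `context_length` tokens, roughly as shown below:\n",
    "\n",
    "<img src=\"https://huggingface.co/datasets/huggingface-course/documentation-images/resolve/main/en/chapter7/chunking_texts.svg\"/>\n",
    "\n",
    "4. A `gpt2` model needs to be told where a sentence begins and where it ends, so we need to set `bos_token`, `eos_token`, and `unk_token`."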
   ],
   "metadata": {
    "collapsed": false
   }
  },
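  {
   "cell_type": "markdown",
   "source": [
    "To make the chunking in item 3 concrete, here is a minimal pure-Python sketch. It is only an illustration under simplifying assumptions: `chunk_tokens` and `keep_full_chunks` are hypothetical helper names, not part of `transformers`, and the special tokens that the real tokenizer adds to every chunk are ignored.\n",
    "\n",
    "```python\n",
    "def chunk_tokens(token_ids, context_length):\n",
    "    # Split one token sequence into consecutive windows of at most context_length tokens.\n",
    "    return [\n",
    "        token_ids[start:start + context_length]\n",
    "        for start in range(0, len(token_ids), context_length)\n",
    "    ]\n",
    "\n",
    "\n",
    "def keep_full_chunks(chunks, context_length):\n",
    "    # Keep only the windows that are exactly context_length tokens long.\n",
    "    return [chunk for chunk in chunks if len(chunk) == context_length]\n",
    "\n",
    "\n",
    "token_ids = list(range(10))  # stand-in for the token ids of one document\n",
    "chunks = chunk_tokens(token_ids, context_length=4)\n",
    "full_chunks = keep_full_chunks(chunks, context_length=4)\n",
    "\n",
    "print(chunks)       # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]\n",
    "print(full_chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]\n",
    "```\n",
    "\n",
    "Dropping the short tail chunk mirrors the length filter applied when the full dataset is tokenized, so every training sequence has the same length."
   ],
   "metadata": {
    "collapsed": false
   }
  },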
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoConfig\n",
    "\n",
    "context_length = 512\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")\n",
    "\n",
    "outputs = tokenizer(\n",
    "    raw_datasets[\"train\"][:2][\"content\"],\n",
    "    truncation=True,\n",
    "    max_length=context_length,\n",
    "    return_overflowing_tokens=True,\n",
    "    return_length=True,\n",
    ")\n",
    "\n",
    "print(f\"Input IDs length: {len(outputs['input_ids'])}\")\n",
    "print(f\"Input chunk lengths: {(outputs['length'])}\")\n",
    "print(f\"Chunk mapping: {outputs['overflow_to_sample_mapping']}\")"
   ]
  },
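  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register `<|endoftext|>` as the beginning-of-text, end-of-text, and unknown token.\n",
    "tokenizer.add_special_tokens(\n",
    "    special_tokens_dict={\n",
    "        \"bos_token\": \"<|endoftext|>\",\n",
    "        \"eos_token\": \"<|endoftext|>\",\n",
    "        \"unk_token\": \"<|endoftext|>\",\n",
    "    }\n",
    ")"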
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 说明\n",
    "1. 其实，到这里，我做的东西基本上就结束了。\n",
    "2. 只要查看这个链接，说明的更加清楚：[https://huggingface.co/course/zh-CN/chapter7/6?fw=pt](https://huggingface.co/course/zh-CN/chapter7/6?fw=pt),我也就是模仿这个链接的。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(element):\n",
    "    # Split every text in the batch into chunks of at most context_length tokens.\n",
    "    outputs = tokenizer(\n",
    "        element[\"content\"],\n",
    "        truncation=True,\n",
    "        max_length=context_length,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_length=True,\n",
    "    )\n",
    "    input_batch = []\n",
    "    for length, input_ids in zip(outputs[\"length\"], outputs[\"input_ids\"]):\n",
    "        # Keep only full-length chunks; shorter tail chunks are dropped.\n",
    "        if length == context_length:\n",
    "            input_batch.append(input_ids)\n",
    "    return {\"input_ids\": input_batch}\n",
    "\n",
    "\n",
    "tokenized_datasets = raw_datasets.map(\n",
    "    tokenize, batched=True, remove_columns=raw_datasets[\"train\"].column_names\n",
    ")\n",
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 模型\n",
    "1. 这里的`gpt2`模型，可不是使用别人训练好的，就是一个gpt2配置，因为我们要使用这个从头训练一个新的`gpt2`"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig\n",
    "\n",
    "config = AutoConfig.from_pretrained(\n",
    "    \"gpt2\",\n",
    "    vocab_size=len(tokenizer),\n",
    "    n_ctx=context_length,\n",
    "    bos_token_id=tokenizer.bos_token_id,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GPT2LMHeadModel(config)\n",
    "model_size = sum(t.numel() for t in model.parameters())\n",
    "print(f\"GPT-2 size: {model_size/1000**2:.1f}M parameters\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = data_collator([tokenized_datasets[\"train\"][i] for i in range(5)])\n",
    "for key in out:\n",
    "    print(f\"{key} shape: {out[key].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "args = TrainingArguments(\n",
    "    output_dir=\"chinese_gpt2_big\",\n",
    "    per_device_train_batch_size=20,\n",
    "    per_device_eval_batch_size=16,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=2_000,\n",
    "    logging_steps=2_000,\n",
    "    gradient_accumulation_steps=8,\n",
    "    num_train_epochs=2,\n",
    "    weight_decay=0.1,\n",
    "    warmup_steps=1_000,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    learning_rate=5e-4,\n",
    "    save_steps=2_000,\n",
    "    fp16=True,\n",
    "    push_to_hub=False,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"valid\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "          "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mynet2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "110bc624a448454d574a0cd6cc76359fd86f75739e493913b3d71c2e04f2ffdb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
